import tensorflow as tf
import os
import numpy as np
import re
import matplotlib.pyplot as plt
from PIL import Image

class NodeLookup(object):
    """Translate Inception softmax node ids into human-readable class names.

    Two label files are combined:

    * ``uid_lookup_path``: tab-separated lines ``<uid>\\t<human label>``.
    * ``label_lookup_path``: a ``.pbtxt`` file whose ``target_class:`` /
      ``target_class_string:`` lines map node ids to uids.
    """

    def __init__(self,
                 label_lookup_path='imagenet_2012_challenge_label_map_proto.pbtxt',
                 uid_lookup_path='imagenet_synset_to_human_label_map.txt'):
        """Load the node-id -> name table.

        Args:
            label_lookup_path: Path to the node-id -> uid ``.pbtxt`` file.
            uid_lookup_path: Path to the uid -> human-label text file.
        """
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    @staticmethod
    def _read_lines(path):
        """Return all lines of *path* (read via tf.gfile), closing the file."""
        with tf.gfile.GFile(path) as f:
            return f.readlines()

    @staticmethod
    def _parse_uid_to_human(lines):
        """Build ``{uid: human_label}`` from tab-separated lines.

        Blank lines are skipped. A non-blank line without a tab raises
        IndexError.
        """
        uid_to_human = {}
        for line in lines:
            # Remove only the line terminator (LF or CRLF); keep the label text.
            line = line.rstrip('\r\n')
            if not line:
                continue
            fields = line.split('\t')
            uid_to_human[fields[0]] = fields[1]
        return uid_to_human

    @staticmethod
    def _parse_node_id_to_uid(lines):
        """Build ``{node_id: uid}`` from ``target_class`` lines of a label proto.

        Each ``target_class_string`` is paired with the most recently seen
        ``target_class``. Surrounding whitespace is stripped before matching, so
        the parse does not depend on the file's indentation width. (A fixed
        4-space prefix silently matches nothing if the file is indented
        differently -- NOTE(review): the stock label map appears to use 2.)
        """
        node_id_to_uid = {}
        target_class = None
        for raw_line in lines:
            line = raw_line.strip()
            if line.startswith('target_class:'):
                target_class = int(line.split(': ')[1])
            elif line.startswith('target_class_string:'):
                # The value is a double-quoted uid; drop the quotes. This also
                # works when the last line has no trailing newline.
                node_id_to_uid[target_class] = line.split(': ')[1].strip('"')
        return node_id_to_uid

    def load(self, label_lookup_path, uid_lookup_path):
        """Return ``{node_id: human_label}`` built from the two label files.

        Raises:
            KeyError: if the proto references a uid missing from the uid file.
        """
        uid_to_human = self._parse_uid_to_human(self._read_lines(uid_lookup_path))
        node_id_to_uid = self._parse_node_id_to_uid(
            self._read_lines(label_lookup_path))
        return {node_id: uid_to_human[uid]
                for node_id, uid in node_id_to_uid.items()}

    def id_to_string(self, node_id):
        """Return the class name for *node_id*, or None if it is unknown."""
        return self.node_lookup.get(node_id)

# Read the frozen Inception GraphDef (a binary protobuf) and import it into the
# default graph so its tensors can later be looked up by name.
# NOTE(review): machine-specific absolute path -- adjust for your environment.
with tf.gfile.GFile('D:\\Users\\kai\\PycharmProjects\\python_mat\\tensorFlow——\\inception_model\\classify_image_graph_def.pb', 'rb') as model_file:
    serialized_graph = model_file.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_graph)
# name='' imports the nodes without a name prefix.
tf.import_graph_def(graph_def, name='')

# Classify every file under images/ with the imported graph and print the
# five highest-scoring labels for each.
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    # Build the id -> name table once: constructing NodeLookup re-reads and
    # re-parses both label files, so doing it per image repeated that work.
    node_lookup = NodeLookup()
    # Walk the image directory recursively.
    for root, _dirs, file_names in os.walk('images/'):
        for file_name in file_names:
            image_path = os.path.join(root, file_name)
            # Feed the raw encoded bytes to the graph's DecodeJpeg input.
            with tf.gfile.GFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            predictions = sess.run(softmax_tensor,
                                   {'DecodeJpeg/contents:0': image_data})
            # Squeeze to 1-D so scores can be indexed by node id
            # (assumes the graph returns a single-image batch).
            predictions = np.squeeze(predictions)

            print(image_path)

            # Display the image; the context manager closes the file handle.
            with Image.open(image_path) as img:
                plt.imshow(img)
                plt.axis('off')
                plt.show()

            # Indices of the 5 highest scores, highest first.
            top_k = predictions.argsort()[-5:][::-1]
            for node_id in top_k:
                human_string = node_lookup.id_to_string(node_id)
                score = predictions[node_id]
                print('%s(score = %.5f)' % (human_string, score))
            print()